{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Adult_noisy_label.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ds02bR9Kw7gG",
        "colab_type": "text"
      },
      "source": [
        "##### Copyright 2020 Google LLC. All Rights Reserved.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "> http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wXgeIg0ow9bh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import random\n",
        "from sklearn import linear_model as sklearn_linear_model\n",
        "from sklearn.model_selection import train_test_split\n",
        "import tensorflow.compat.v1 as tf"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QOkNrui2vJaQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tf.disable_eager_execution()"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "siRe26G2ypcY",
        "colab_type": "text"
      },
      "source": [
        "This colab contains TensorFlow code for implementing the surrogate projected gradient descent method presented in the paper:\n",
        "> Qijia Jiang, Olaoluwa Adigun, Harikrishna Narasimhan, Mahdi M. Fard, Maya Gupta, 'Optimizing Black-box Metrics with Adaptive Surrogates', ICML 2020. [[PDF]](https://arxiv.org/pdf/2002.08605.pdf)\n",
        "\n",
        "We consider the problem of learning with noisy training labels, given access to a small validation sample with the true labels. We seek to train a linear classifier that performs well on the G-mean metric. We apply the approach proposed in the paper to adaptively combine surrogates on the training set to best optimize the given metric on the validation sample. We will demonstrate the effectiveness of this approach on the [UCI Adult dataset](https://archive.ics.uci.edu/ml/datasets/adult)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_exIz1ae5e7e"
      },
      "source": [
        "## Load dataset\n",
        "\n",
        "The data preprocessing code has been adapted from <a href=\"https://github.com/google-research/tensorflow_constrained_optimization/blob/master/examples/jupyter/Fairness_adult.ipynb\">this colab</a>."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cUFXkl1YEC5X",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "CATEGORICAL_COLUMNS = [\n",
        "    'workclass', 'education', 'marital_status', 'occupation', 'relationship',\n",
        "    'race', 'gender', 'native_country'\n",
        "]\n",
        "CONTINUOUS_COLUMNS = [\n",
        "    'age', 'capital_gain', 'capital_loss', 'hours_per_week', 'education_num'\n",
        "]\n",
        "COLUMNS = [\n",
        "    'age', 'workclass', 'fnlwgt', 'education', 'education_num',\n",
        "    'marital_status', 'occupation', 'relationship', 'race', 'gender',\n",
        "    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',\n",
        "    'income_bracket'\n",
        "]\n",
        "LABEL_COLUMN = 'label'\n",
        "\n",
        "# We'll find it useful to consider two groups of examples.\n",
        "GROUPS = ['private_workforce', 'non_private_workforce']\n",
        "\n",
        "\n",
        "def get_adult_data():\n",
        "  train_df_raw = pd.read_csv(\n",
        "      \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\", \n",
        "      names=COLUMNS, skipinitialspace=True)\n",
        "  test_df_raw = pd.read_csv(\n",
        "      \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test\", \n",
        "      names=COLUMNS, skipinitialspace=True, skiprows=1)\n",
        "\n",
        "  train_df_raw[LABEL_COLUMN] = (train_df_raw['income_bracket'].apply(\n",
        "      lambda x: '>50K' in x)).astype(int)\n",
        "  test_df_raw[LABEL_COLUMN] = (test_df_raw['income_bracket'].apply(\n",
        "      lambda x: '>50K' in x)).astype(int)\n",
        "\n",
        "  # Preprocessing Features\n",
        "  pd.options.mode.chained_assignment = None  # default='warn'\n",
        "\n",
        "  # Functions for preprocessing categorical and continuous columns.\n",
        "  def binarize_categorical_columns(\n",
        "      input_train_df, input_test_df, categorical_columns=[]):\n",
        "\n",
        "      def fix_columns(input_train_df, input_test_df):\n",
        "          test_df_missing_cols = set(input_train_df.columns) - set(\n",
        "              input_test_df.columns)\n",
        "          for c in test_df_missing_cols:\n",
        "              input_test_df[c] = 0\n",
        "          train_df_missing_cols = set(input_test_df.columns) - set(\n",
        "              input_train_df.columns)\n",
        "          for c in train_df_missing_cols:\n",
        "              input_train_df[c] = 0\n",
        "          # Reorder the train columns to match the test columns.\n",
        "          input_train_df = input_train_df[input_test_df.columns]\n",
        "          return input_train_df, input_test_df\n",
        "\n",
        "      # Binarize categorical columns.\n",
        "      binarized_train_df = pd.get_dummies(\n",
        "          input_train_df, columns=categorical_columns)\n",
        "      binarized_test_df = pd.get_dummies(\n",
        "          input_test_df, columns=categorical_columns)\n",
        "      # Make sure the train and test dataframes have the same binarized columns.\n",
        "      fixed_train_df, fixed_test_df = fix_columns(\n",
        "          binarized_train_df, binarized_test_df)\n",
        "      return fixed_train_df, fixed_test_df\n",
        "\n",
        "  def bucketize_continuous_column(\n",
        "        input_train_df, input_test_df, continuous_column_name,\n",
        "        num_quantiles=None, bins=None):\n",
        "      assert (num_quantiles is None or bins is None)\n",
        "      if num_quantiles is not None:\n",
        "          train_quantized, bins_quantized = pd.qcut(\n",
        "            input_train_df[continuous_column_name],\n",
        "            num_quantiles,\n",
        "            retbins=True,\n",
        "            labels=False)\n",
        "          input_train_df[continuous_column_name] = pd.cut(\n",
        "            input_train_df[continuous_column_name], bins_quantized, \n",
        "            labels=False)\n",
        "          input_test_df[continuous_column_name] = pd.cut(\n",
        "            input_test_df[continuous_column_name], bins_quantized, labels=False)\n",
        "      elif bins is not None:\n",
        "          input_train_df[continuous_column_name] = pd.cut(\n",
        "            input_train_df[continuous_column_name], bins, labels=False)\n",
        "          input_test_df[continuous_column_name] = pd.cut(\n",
        "            input_test_df[continuous_column_name], bins, labels=False)\n",
        "\n",
        "  # Filter out all columns except the ones specified.\n",
        "  train_df = (\n",
        "      train_df_raw[CATEGORICAL_COLUMNS + CONTINUOUS_COLUMNS + [LABEL_COLUMN]])\n",
        "  test_df = (\n",
        "      test_df_raw[CATEGORICAL_COLUMNS + CONTINUOUS_COLUMNS + [LABEL_COLUMN]])\n",
        "  \n",
        "  # Bucketize continuous columns.\n",
        "  bucketize_continuous_column(train_df, test_df, 'age', num_quantiles=4)\n",
        "  bucketize_continuous_column(\n",
        "      train_df, test_df, 'capital_gain', bins=[-1, 1, 4000, 10000, 100000])\n",
        "  bucketize_continuous_column(\n",
        "      train_df, test_df, 'capital_loss', bins=[-1, 1, 1800, 1950, 4500])\n",
        "  bucketize_continuous_column(\n",
        "      train_df, test_df, 'hours_per_week', bins=[0, 39, 41, 50, 100])\n",
        "  bucketize_continuous_column(\n",
        "      train_df, test_df, 'education_num', bins=[0, 8, 9, 11, 16])\n",
        "  \n",
        "  train_df, test_df = binarize_categorical_columns(\n",
        "      train_df, test_df, \n",
        "      categorical_columns=CATEGORICAL_COLUMNS + CONTINUOUS_COLUMNS)\n",
        "  feature_names = list(train_df.keys())\n",
        "  feature_names.remove(LABEL_COLUMN)\n",
        "  num_features = len(feature_names)\n",
        "\n",
        "  # Add boolean group membership columns: private vs. non-private workclass.\n",
        "  train_df[\"private_workforce\"] = train_df[\"workclass_Private\"]\n",
        "  train_df[\"non_private_workforce\"] = 1 - train_df[\"workclass_Private\"]\n",
        "  test_df[\"private_workforce\"] = test_df[\"workclass_Private\"]\n",
        "  test_df[\"non_private_workforce\"] = 1 - test_df[\"workclass_Private\"]\n",
        "  \n",
        "  return train_df, test_df, feature_names"
      ],
      "execution_count": 3,
      "outputs": []
    },
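    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dummies_note_md",
        "colab_type": "text"
      },
      "source": [
        "As a minimal standalone illustration (toy data, added for exposition), here is why `binarize_categorical_columns` must align the train and test columns: `pd.get_dummies` only creates indicator columns for the categories it actually sees in each frame."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dummies_note_code",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import pandas as pd\n",
        "\n",
        "# 'blue' appears only in the toy train frame, 'green' only in the toy test frame.\n",
        "toy_train = pd.DataFrame({'color': ['red', 'blue']})\n",
        "toy_test = pd.DataFrame({'color': ['red', 'green']})\n",
        "\n",
        "dummies_train = pd.get_dummies(toy_train, columns=['color'])\n",
        "dummies_test = pd.get_dummies(toy_test, columns=['color'])\n",
        "\n",
        "print(sorted(dummies_train.columns))  # ['color_blue', 'color_red']\n",
        "print(sorted(dummies_test.columns))   # ['color_green', 'color_red']"
      ],
      "execution_count": null,
      "outputs": []
    },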
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FbxlKeAF916G",
        "colab_type": "text"
      },
      "source": [
        "We separate out the train and test features, labels, and group memberships."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vRNJuPjx9z6m",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Train and test features, labels.\n",
        "train_df, test_df, feature_names = get_adult_data()\n",
        "\n",
        "x_train = np.array(train_df[feature_names])\n",
        "y_train = np.array(train_df[LABEL_COLUMN])\n",
        "\n",
        "x_test = np.array(test_df[feature_names])\n",
        "y_test = np.array(test_df[LABEL_COLUMN])\n",
        "\n",
        "# Train and test group memberships. We maintain a list of group memberships 'z'\n",
        "# where each entry is an array of boolean memberships of shape (n,).\n",
        "z_train = [np.array(train_df[g]) for g in GROUPS]\n",
        "z_test = [np.array(test_df[g]) for g in GROUPS]"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jCsw4Y5r0euQ",
        "colab_type": "text"
      },
      "source": [
        "## Add noise to training labels\n",
        "We retain 1% of the training sample for validation; in the remaining data, we flip each positive label to negative with probability 0.3."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rnQ5tzgf06BN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Split training sample into validation and training sets.\n",
        "train_indices, vali_indices = train_test_split(\n",
        "    np.arange(x_train.shape[0]), test_size=0.01, random_state=40)\n",
        "\n",
        "x_vali = x_train[vali_indices, :]\n",
        "y_vali = y_train[vali_indices]\n",
        "z_vali = [z_train[kk][vali_indices] for kk in range(len(z_train))]\n",
        "\n",
        "x_train = x_train[train_indices, :]\n",
        "y_train = y_train[train_indices]\n",
        "z_train = [z_train[kk][train_indices] for kk in range(len(z_train))]\n",
        "\n",
        "# Flip 30% of positive training labels.\n",
        "np.random.seed(123456)\n",
        "noise_indicators = np.random.rand(x_train.shape[0],) <= 0.3\n",
        "y_train[(y_train == 1) * noise_indicators] = 0"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2MrJueBP2AxX",
        "colab_type": "text"
      },
      "source": [
        "## Evaluation function\n",
        "We include a function to evaluate the G-mean metric for given labels and predicted scores. We report $1 - \sqrt{\text{TPR} \times \text{TNR}}$, i.e. one minus the geometric mean of the true positive rate (TPR) and true negative rate (TNR), so lower values are better."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MiRF7kJJ2bZc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def g_mean(labels, scores):  \n",
        "  tpr = np.sum((labels > 0) & (scores > 0)) / np.sum(labels > 0)\n",
        "  tnr = np.sum((labels <= 0) & (scores <= 0)) / np.sum(labels <= 0)\n",
        "  return 1 - np.sqrt(tpr * tnr)\n",
        "\n",
        "\n",
        "def print_results(linear_model, x_vali, y_vali, x_test, y_test):\n",
        "  weights, threshold = linear_model\n",
        "  scores_vali = np.dot(x_vali, weights) + threshold\n",
        "  scores_test = np.dot(x_test, weights) + threshold\n",
        "\n",
        "  print(\"Vali G-mean: %.4f\" % g_mean(y_vali, scores_vali))\n",
        "  print(\"Test G-mean: %.4f\" % g_mean(y_test, scores_test))"
      ],
      "execution_count": 6,
      "outputs": []
    },
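    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gmean_example_md",
        "colab_type": "text"
      },
      "source": [
        "For intuition, a small worked example with toy labels and scores (added for illustration): two positives of which one is scored positive (TPR = 0.5), and two negatives of which one is scored negative (TNR = 0.5), so `g_mean` reports 1 - sqrt(0.5 * 0.5) = 0.5."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gmean_example_code",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "toy_labels = np.array([1, 1, 0, 0])\n",
        "# One positive and one negative example classified correctly.\n",
        "toy_scores = np.array([2.0, -1.0, -3.0, 1.0])\n",
        "print(g_mean(toy_labels, toy_scores))  # 0.5"
      ],
      "execution_count": null,
      "outputs": []
    },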
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w9mrCCcr4dsX",
        "colab_type": "text"
      },
      "source": [
        "## Train classifier with cross-entropy loss\n",
        "\n",
        "As our first baseline, we train a linear model by optimizing a plain cross-entropy loss."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2DcPSadt4l0H",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def optimize_cross_entropy_loss(\n",
        "    x_train, y_train, x_vali, y_vali, learning_rate=0.1, loops=1000):\n",
        "  tf.reset_default_graph()\n",
        "\n",
        "  # Set random seed for reproducibility.\n",
        "  np.random.seed(123456)\n",
        "  random.seed(654321)\n",
        "  tf.set_random_seed(121212)  \n",
        "\n",
        "  # Linear model.\n",
        "  dimension = x_train.shape[-1]\n",
        "  weights = tf.Variable(tf.zeros(dimension, dtype=tf.float32), name=\"weights\")\n",
        "  threshold = tf.Variable(0, name=\"threshold\", dtype=tf.float32)\n",
        "\n",
        "  # Labels and predictions on train set.\n",
        "  features_tensor = tf.constant(x_train.astype(\"float32\"), name=\"features\")\n",
        "  labels_tensor = tf.constant(y_train.astype(\"float32\"), name=\"labels\")\n",
        "  predictions_tensor = tf.tensordot(\n",
        "      features_tensor, weights, axes=(1, 0)) + threshold\n",
        "\n",
        "  # Labels and predictions on vali set.\n",
        "  features_tensor_vali = tf.constant(\n",
        "      x_vali.astype(\"float32\"), name=\"features_vali\")\n",
        "  labels_tensor_vali = tf.constant(\n",
        "      y_vali.astype(\"float32\"), name=\"labels_vali\")\n",
        "  predictions_tensor_vali = tf.tensordot(\n",
        "      features_tensor_vali, weights, axes=(1, 0)) + threshold\n",
        "  \n",
        "  # Cross-entropy loss.\n",
        "  loss_tensor_train = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(\n",
        "      labels=labels_tensor, logits=predictions_tensor))\n",
        "  loss_tensor_vali = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(\n",
        "      labels=labels_tensor_vali, logits=predictions_tensor_vali))\n",
        "  \n",
        "  # Optimizer.\n",
        "  train_op = tf.train.AdagradOptimizer(learning_rate).minimize(\n",
        "      loss_tensor_train)\n",
        "  \n",
        "  # Start TF session and initialize variables.\n",
        "  session = tf.Session()\n",
        "  session.run(tf.global_variables_initializer())\n",
        "\n",
        "  # Track the best model seen during training, by validation loss.\n",
        "  best_model = None\n",
        "  min_loss = 1e10\n",
        "\n",
        "  # Perform full gradient updates.\n",
        "  for ii in range(loops):\n",
        "    # Gradient update.\n",
        "    session.run(train_op)\n",
        "    if (ii % 10 == 0):\n",
        "      model = [session.run(weights), session.run(threshold)]\n",
        "      loss_vali = session.run(loss_tensor_vali)\n",
        "      if loss_vali < min_loss:\n",
        "        min_loss = loss_vali\n",
        "        best_model = model\n",
        "\n",
        "  session.close()\n",
        "  return best_model"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ub5J5OAB5jZu",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "21463ae2-939e-4f64-83b4-3206871cdd49"
      },
      "source": [
        "model_ce = optimize_cross_entropy_loss(x_train, y_train, x_vali, y_vali)\n",
        "print_results(model_ce, x_vali, y_vali, x_test, y_test)"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Vali G-mean: 0.4908\n",
            "Test G-mean: 0.4396\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wUFor24R5_GZ",
        "colab_type": "text"
      },
      "source": [
        "## Post-shift previous model\n",
        "\n",
        "As our second baseline, we tune a threshold for the previously trained model so as to optimize G-mean on the validation set."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h3JH7ze16qUk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def post_shift(linear_model, x_vali, y_vali, bin_size=0.001):\n",
        "  # Compute class probabilities.\n",
        "  weights, threshold = linear_model\n",
        "  sigmoid = lambda x: 1 / (1 + np.exp(-x))\n",
        "  y_prob = sigmoid(np.dot(x_vali, weights) + threshold)\n",
        "\n",
        "  # Tune threshold on the probabilities to minimize G-mean on validation set.\n",
        "  min_gm = 2.0\n",
        "  min_threshold = -1\n",
        "  for tt in np.arange(bin_size, 1, bin_size):\n",
        "    gm = g_mean(y_vali, y_prob - tt)\n",
        "    if gm < min_gm:\n",
        "      min_gm = gm\n",
        "      min_threshold = tt\n",
        "\n",
        "  # Apply the logit (inverse sigmoid) transform to map the tuned threshold\n",
        "  # from the (0, 1) probability range back to the real line.\n",
        "  post_shifted_threshold = (\n",
        "      threshold - np.log(min_threshold / (1 - min_threshold)))\n",
        "  \n",
        "  return (weights, post_shifted_threshold)"
      ],
      "execution_count": 9,
      "outputs": []
    },
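    {
      "cell_type": "markdown",
      "metadata": {
        "id": "postshift_check_md",
        "colab_type": "text"
      },
      "source": [
        "A quick standalone check of the transform above (toy numbers, added for illustration): thresholding the sigmoid probabilities at t gives the same decisions as thresholding the raw scores at log(t / (1 - t)), which is why subtracting that quantity from the model's bias reproduces the tuned classifier."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "postshift_check_code",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "t = 0.3  # an example probability threshold in (0, 1)\n",
        "logit_t = np.log(t / (1 - t))\n",
        "toy_scores = np.array([-2.0, -0.5, 0.0, 1.0])\n",
        "toy_probs = 1 / (1 + np.exp(-toy_scores))\n",
        "print(np.array_equal(toy_probs > t, toy_scores > logit_t))  # True"
      ],
      "execution_count": null,
      "outputs": []
    },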
    {
      "cell_type": "code",
      "metadata": {
        "id": "bfE9gfyz7nHE",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "7396e717-0ad0-43cc-88cc-a83a2bf60bfd"
      },
      "source": [
        "model_ps = post_shift(model_ce, x_vali, y_vali)\n",
        "print_results(model_ps, x_vali, y_vali, x_test, y_test)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Vali G-mean: 0.1716\n",
            "Test G-mean: 0.1827\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nq9RN3p_9XXD",
        "colab_type": "text"
      },
      "source": [
        "## Proposed method\n",
        "We next apply our proposed method to adaptively combine surrogates on the training set to best optimize the G-mean on the validation sample. We use *four surrogate losses*: the *sigmoid* loss averaged over the negative examples and over the positive examples, evaluated separately on two groups of examples, 'private_workforce' and 'non_private_workforce'. There is no particular reason for choosing these two groups; our approach works with any choice of surrogates and data partitioning.\n",
        "\n",
        "We begin by writing a function that computes the sigmoid surrogate for a given real-valued input. We additionally write a function that computes a surrogate tensor for a 1-dimensional tensor input. The former is needed for gradient computation with NumPy operations; the latter, for optimizing the surrogates in TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RWbiQ9LBj33Q",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Sigmoid surrogate functions.\n",
        "def surrogate_fn(x, sigma_scale=1.5):\n",
        "  x1 = -np.maximum(0, x * sigma_scale)\n",
        "  x2 = np.minimum(0, x * sigma_scale)\n",
        "  return np.exp(x2) / (np.exp(x1) + np.exp(x2))\n",
        "\n",
        "def surrogate_tensor_fn(x, sigma_scale=1.5): \n",
        "  return tf.math.sigmoid(-x * sigma_scale)"
      ],
      "execution_count": 11,
      "outputs": []
    },
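    {
      "cell_type": "markdown",
      "metadata": {
        "id": "surrogate_check_md",
        "colab_type": "text"
      },
      "source": [
        "As a sanity check (the arithmetic is restated below so the cell runs on its own), the two-branch form in `surrogate_fn` is a numerically stable evaluation of a sigmoid: it agrees with the direct formula on a moderate range while avoiding overflow in `np.exp` for large inputs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "surrogate_check_code",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sigma_scale = 1.5\n",
        "x = np.linspace(-20, 20, 101)\n",
        "x1 = -np.maximum(0, x * sigma_scale)\n",
        "x2 = np.minimum(0, x * sigma_scale)\n",
        "stable = np.exp(x2) / (np.exp(x1) + np.exp(x2))\n",
        "direct = 1 / (1 + np.exp(-sigma_scale * x))\n",
        "print(np.allclose(stable, direct))  # True"
      ],
      "execution_count": null,
      "outputs": []
    },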
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uA-O13Jv0AQr",
        "colab_type": "text"
      },
      "source": [
        "We also write a function that computes the metric. This function takes an array of labels and an array of predictions at different perturbations of the model and outputs the metric value for each set of predictions."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KrWtV5EqyAhh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# G-mean metric function.\n",
        "def metric_fn(labels, predictions):\n",
        "  \"\"\"Computes metric for given labels and predictions.\n",
        "\n",
        "  Args: \n",
        "    labels: NumPy array of shape (n,), where n is the number of data points\n",
        "    predictions: NumPy array of shape (m, n), where n is the number of data \n",
        "      points and m >= 1 is the number of perturbations.\n",
        "  \n",
        "  Returns:\n",
        "    A NumPy array of metric values of shape (m,).\n",
        "  \"\"\"\n",
        "  if predictions.ndim < 2:\n",
        "    predictions = predictions.reshape(1, -1)\n",
        "  tpr = np.mean(predictions[:, labels > 0] > 0, axis=1)\n",
        "  tnr = np.mean(predictions[:, labels <= 0] <= 0, axis=1)\n",
        "  return 1.0 - np.sqrt(tpr * tnr)"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6yWV6aj3qL03",
        "colab_type": "text"
      },
      "source": [
        "The proposed algorithm runs *projected gradient descent (PGD)* over the space of surrogates. We first write a function that estimates gradients w.r.t. the surrogates for the current linear model and performs a gradient update in surrogate space, using the linear-interpolation-based strategy described in Algorithm 3 of the paper for the gradient estimation step. The output is a $K$-dimensional profile of updated surrogate values, where $K$ is the number of surrogates used (in this case, $K = 2 \times \text{number of groups} = 4$)."
      ]
    },
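    {
      "cell_type": "markdown",
      "metadata": {
        "id": "regression_toy_md",
        "colab_type": "text"
      },
      "source": [
        "Before the implementation, a self-contained toy illustration of the regression trick (a synthetic quadratic function, not part of the algorithm): perturb the input, regress output differences on input differences, and the fitted coefficients approximately recover the gradient."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "regression_toy_code",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# f has gradient (2 * w[0], 3), i.e. (2, 3) at w0 = (1, -2).\n",
        "f = lambda w: w[..., 0] ** 2 + 3 * w[..., 1]\n",
        "w0 = np.array([1.0, -2.0])\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "toy_perturbations = rng.normal(size=(200, 2)) * 0.01\n",
        "input_diffs = toy_perturbations\n",
        "output_diffs = f(w0 + toy_perturbations) - f(w0)\n",
        "\n",
        "reg = sklearn_linear_model.LinearRegression()\n",
        "reg.fit(input_diffs, output_diffs)\n",
        "print(np.round(reg.coef_, 2))  # approximately [2. 3.]"
      ],
      "execution_count": null,
      "outputs": []
    },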
    {
      "cell_type": "code",
      "metadata": {
        "id": "71DOg6RYa0ag",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "def surrogate_gradient_update(\n",
        "    linear_model, x_train, y_train, z_train, x_vali, y_vali, z_vali,\n",
        "    surrogate_fn, metric_fn, step_size, sigma, num_perturbations):\n",
        "  \"\"\"\n",
        "  Returns updated surrogate profile after performing gradient descent update.\n",
        "  \n",
        "  Estimates gradients w.r.t. the surrogates for given linear model and performs\n",
        "  a gradient descent update. Implements lines 4-5 in Algorithm 1 in the paper \n",
        "  and uses Algorithm 3 for the gradient descent update.\n",
        "\n",
        "  Args:\n",
        "    linear_model: (weights, threshold)\n",
        "    x_train, y_train, z_train: Training features, labels, list of boolean\n",
        "      group memberships\n",
        "    x_vali, y_vali, z_vali: Validation features, labels, list of boolean\n",
        "      group memberships\n",
        "    surrogate_fn: A function that takes a scalar input and outputs a real value\n",
        "    metric_fn: A function that takes as input the labels and predictions at\n",
        "      different perturbations of the model, and outputs an array of metrics\n",
        "    step_size: Step size of the gradient descent update\n",
        "    sigma: Perturbation parameter for gradient computation\n",
        "    num_perturbations: Number of model perturbations for gradient computation\n",
        "\n",
        "  Returns:\n",
        "    a K-dim vector containing values for the K surrogates after gradient descent\n",
        "    update (K = 2 x number of groups)\n",
        "  \"\"\"\n",
        "  model_weights, threshold = linear_model\n",
        "\n",
        "  num_train = x_train.shape[0]\n",
        "  num_vali = x_vali.shape[0]\n",
        "\n",
        "  ############ Line 4 in Algorithm 1 #######################################\n",
        "  ############ Gradient estimation using Algorithm 3 #######################\n",
        "\n",
        "  # Model dimension and number of groups.\n",
        "  dimension = model_weights.shape[0]\n",
        "  num_groups = len(z_train)\n",
        "\n",
        "  # Unperturbed predictions on train and vali.\n",
        "  predictions_train = np.dot(x_train, model_weights) + threshold \n",
        "  predictions_vali = np.dot(x_vali, model_weights) + threshold \n",
        "\n",
        "  # Perturbed predictions on train and vali sets by perturbing model weights.\n",
        "  perturbations = np.random.normal(size=(num_perturbations, dimension)) * sigma\n",
        "  perturbed_weights = model_weights.reshape(\n",
        "      1, -1) + perturbations   # shape (num_perturbations, dimension)\n",
        "  perturbed_predictions_train = np.dot(\n",
        "      perturbed_weights, x_train.T) + threshold \n",
        "  perturbed_predictions_vali = np.dot(perturbed_weights, x_vali.T) + threshold \n",
        "\n",
        "  # Differences between unperturbed and perturbed surrogates on train set \n",
        "  # for each group.\n",
        "  fpr_surrogate_diffs = np.zeros((num_perturbations, num_groups))\n",
        "  fnr_surrogate_diffs = np.zeros((num_perturbations, num_groups))\n",
        "\n",
        "  # Unperturbed surrogate values on the train set, one entry per group.\n",
        "  fpr_surrogates_train = np.zeros(num_groups)\n",
        "  fnr_surrogates_train = np.zeros(num_groups)\n",
        "\n",
        "  for group in range(num_groups):\n",
        "    neg_group_indices = (y_train <= 0) & (z_train[group] == 1)\n",
        "    pos_group_indices = (y_train > 0) & (z_train[group] == 1)\n",
        "\n",
        "    # Unperturbed surrogates on train set for group.\n",
        "    fpr_surrogates_train[group] = np.mean(\n",
        "        surrogate_fn(predictions_train[neg_group_indices]))\n",
        "    fnr_surrogates_train[group] = np.mean(\n",
        "        surrogate_fn(-predictions_train[pos_group_indices]))\n",
        "\n",
        "    # Perturbed surrogates on train set for group.\n",
        "    perturbed_fpr_surrogate_train = np.mean(\n",
        "        surrogate_fn(perturbed_predictions_train[:, neg_group_indices]), axis=1)\n",
        "    perturbed_fnr_surrogate_train = np.mean(\n",
        "        surrogate_fn(-perturbed_predictions_train[:, pos_group_indices]), axis=1)\n",
        "\n",
        "    # Differences between unperturbed and perturbed surrogates for group.\n",
        "    fpr_surrogate_diffs[:, group] = (\n",
        "        perturbed_fpr_surrogate_train - fpr_surrogates_train[group])\n",
        "    fnr_surrogate_diffs[:, group] = (\n",
        "        perturbed_fnr_surrogate_train - fnr_surrogates_train[group])\n",
        "\n",
        "  # Concatenate the fpr and fnr surrogate differences.\n",
        "  surrogate_diffs = np.concatenate(\n",
        "      [fpr_surrogate_diffs, fnr_surrogate_diffs], axis=1)\n",
        "\n",
        "  # Calculate the differences in G-mean metric with and without perturbations.\n",
        "  gm_vali = metric_fn(y_vali, predictions_vali)\n",
        "  perturbed_gm_vali = metric_fn(y_vali, perturbed_predictions_vali)\n",
        "  metric_diffs = perturbed_gm_vali - gm_vali\n",
        "\n",
        "  # Fit a linear regression model from surrogate diffs to metric diffs. The\n",
        "  # fitted coefficients give the gradient estimate w.r.t. the surrogates.\n",
        "  reg_model = sklearn_linear_model.LinearRegression()\n",
        "  reg_model.fit(surrogate_diffs, metric_diffs)\n",
        "  grad = reg_model.coef_\n",
        "\n",
        "  ############ Line 5 in Algorithm 1 #######################################\n",
        "  ############ Gradient descent update in surrogate space ##################\n",
        "  fpr_surrogate_train_new = fpr_surrogates_train - step_size * grad[:num_groups]\n",
        "  fnr_surrogate_train_new = fnr_surrogates_train - step_size * grad[num_groups:]\n",
        "\n",
        "  return fpr_surrogate_train_new, fnr_surrogate_train_new"
      ],
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "58Gic2ymr8TS",
        "colab_type": "text"
      },
      "source": [
        "We are now ready to write the main PGD routine which calls the above routine."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "abjCMPbpvdwG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def surrogate_pgd(\n",
        "    x_train, y_train, z_train, x_vali, y_vali, z_vali, surrogate_fn, \n",
        "    surrogate_tensor_fn, metric_fn, step_size_outer=0.1, step_size_inner=1.0,\n",
        "    outer_loops=250, inner_loops=100, sigma=1.0, num_perturbations=100):\n",
        "  \"\"\"\n",
        "  Trains a linear model by running projected gradient descent in surrogate \n",
        "  space with given surrogates and metric.\n",
        "  \n",
        "  Implements Algorithm 1 in the paper, with calls to Algorithm 3 to estimate\n",
        "  gradients.\n",
        "\n",
        "  Args:\n",
        "    x_train, y_train, z_train: Training features, labels, list of boolean\n",
        "      group memberships\n",
        "    x_vali, y_vali, z_vali: Validation features, labels, list of boolean\n",
        "      group memberships\n",
        "    surrogate_fn: A function that takes a scalar input and outputs a real-value\n",
        "    surrogate_tensor_fn: A function that takes a tensor input and outputs \n",
        "      a tensor of surrogate values.\n",
        "    metric_fn: A function that takes  as input the labels and predictions at\n",
        "      different perturbations of the model, and outputs an array of metrics\n",
        "    step_size_outer: Step size for outer projected gradient descent\n",
        "    step_size_inner: Step size for inner Adagrad subroutine to implement the\n",
        "      projection step\n",
        "    outer_loops: Number of steps to run the outer projected gradient descent\n",
        "    inner_loops: Number of steps to run the inner Adagrad subroutine\n",
        "    sigma: Perturbation parameter for gradient computation\n",
        "    num_perturbations: Number of model perturbations for gradient computation\n",
        "\n",
        "  Returns:\n",
        "    A list [weights, threshold] for the best model found on the\n",
        "    validation set.\n",
        "  \"\"\"\n",
        "  tf.reset_default_graph()\n",
        "\n",
        "  # Set random seed for reproducibility.\n",
        "  np.random.seed(123456)\n",
        "  random.seed(654321)\n",
        "  tf.set_random_seed(121212)  \n",
        "\n",
        "  # Model dimension and number of groups.\n",
        "  dimension = x_train.shape[-1]\n",
        "  num_groups = len(z_train)\n",
        " \n",
        "  # Linear model.\n",
        "  weights = tf.Variable(\n",
        "      tf.zeros((dimension,), dtype=tf.float32), name=\"weights\")\n",
        "  threshold = tf.Variable(0, name=\"threshold\", dtype=tf.float32)\n",
        "\n",
        "  # Labels, groups, predictions on train set.\n",
        "  features_tensor = tf.constant(x_train.astype(\"float32\"), name=\"features\")\n",
        "  labels_tensor = tf.constant(y_train.astype(\"float32\"), name=\"labels\")\n",
        "  groups_tensor = [tf.constant(z_train[kk].astype(\"float32\"), name=\"g_%d\" % kk) \n",
        "                   for kk in range(num_groups)]\n",
        "  predictions_tensor = tf.tensordot(\n",
        "      features_tensor, weights, axes=(1, 0)) + threshold\n",
        "  \n",
        "  # Predictions on vali set.\n",
        "  features_vali_tensor = tf.constant(\n",
        "      x_vali.astype(\"float32\"), name=\"features_vali\")\n",
        "  predictions_vali_tensor = tf.tensordot(\n",
        "      features_vali_tensor, weights, axes=(1, 0)) + threshold\n",
        "\n",
        "  # Surrogate tensors for each group.\n",
        "  fpr_surrogates = []\n",
        "  fnr_surrogates = []\n",
        "\n",
        "  for group in range(num_groups):\n",
        "    # Surrogate on fpr for group.\n",
        "    neg_predictions_group = tf.boolean_mask(\n",
        "        predictions_tensor, \n",
        "        mask=((labels_tensor <= 0) & (groups_tensor[group] > 0)))\n",
        "    fpr_surrogate_group = tf.reduce_mean(\n",
        "        surrogate_tensor_fn(-neg_predictions_group))\n",
        "    fpr_surrogates.append(fpr_surrogate_group)\n",
        "\n",
        "    # Surrogate on fnr for group.\n",
        "    pos_predictions_group = tf.boolean_mask(\n",
        "        predictions_tensor, \n",
        "        mask=((labels_tensor > 0) & (groups_tensor[group] > 0)))\n",
        "    fnr_surrogate_group = tf.reduce_mean(\n",
        "        surrogate_tensor_fn(pos_predictions_group))\n",
        "    fnr_surrogates.append(fnr_surrogate_group)\n",
        "  \n",
        "  # Placeholder tensors for the target surrogate values to project.\n",
        "  fpr_target_tensor = tf.placeholder(tf.float32, shape=(num_groups,))\n",
        "  fnr_target_tensor = tf.placeholder(tf.float32, shape=(num_groups,))\n",
        "\n",
        "  # Projection objective function in Line 6.\n",
        "  projection_objective = 0.0\n",
        "  for group in range(num_groups):\n",
        "    projection_objective += (\n",
        "        tf.maximum(0.0, fpr_surrogates[group] - fpr_target_tensor[group]) ** 2 \n",
        "        + tf.maximum(0.0, fnr_surrogates[group] - fnr_target_tensor[group]) ** 2)\n",
        "\n",
        "  # Create train_op for projection.\n",
        "  train_op = tf.train.AdagradOptimizer(step_size_inner).minimize(\n",
        "      projection_objective)\n",
        "  \n",
        "  # Start TF session and initialize variables.\n",
        "  session = tf.Session()\n",
        "  session.run(tf.global_variables_initializer())\n",
        "\n",
        "  best_model = None\n",
        "  min_gm = 2.0\n",
        "  objectives = []\n",
        "\n",
        "  # Perform full gradient updates.\n",
        "  for ii in range(outer_loops):\n",
        "    \n",
        "    ############ Line 4-5 in Algorithm 1 #######################################\n",
        "    ######### Gradient estimation and update in surrogate space ################\n",
        "    model = [session.run(weights), session.run(threshold)]\n",
        "    fpr_surrogate_new, fnr_surrogate_new = surrogate_gradient_update(\n",
        "        model, x_train, y_train, z_train, x_vali, y_vali, z_vali,\n",
        "        surrogate_fn, metric_fn, step_size_outer, sigma, num_perturbations)\n",
        "\n",
        "    ############ Line 6 in Algorithm 1 #########################################\n",
        "    ######### Project new surrogate profile to \\mathcal{U} #####################\n",
        "    for jj in range(inner_loops):\n",
        "      session.run(train_op, \n",
        "                  feed_dict={\n",
        "                      fpr_target_tensor: fpr_surrogate_new, \n",
        "                      fnr_target_tensor: fnr_surrogate_new})\n",
        "    \n",
        "    # Evaluate G-mean on validation set once every 10 steps.\n",
        "    if ii % 10 == 0:\n",
        "      predictions_vali = session.run(predictions_vali_tensor)\n",
        "      gm_vali = g_mean(y_vali, predictions_vali)\n",
        "      objectives.append(gm_vali)\n",
        "      if gm_vali < min_gm:\n",
        "        min_gm = gm_vali\n",
        "        best_model = [session.run(weights), session.run(threshold)]\n",
        "      print(\"Step = %d | Vali G-mean = %.3f\" % (ii, gm_vali))\n",
        "\n",
        "  session.close()\n",
        "\n",
        "  # Plot objectives with increasing steps.\n",
        "  print()\n",
        "  ff, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
        "  ax.plot(np.arange(len(objectives)) * 10, objectives) \n",
        "  ax.set_xlabel('Gradient steps')\n",
        "  ax.set_ylabel('Vali G-mean')\n",
        "  ff.tight_layout()\n",
        "  plt.show()\n",
        "  print()\n",
        "\n",
        "  return best_model"
      ],
      "execution_count": 16,
      "outputs": []
    },
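    {
      "cell_type": "markdown",
      "metadata": {
        "id": "projection_math_cell",
        "colab_type": "text"
      },
      "source": [
        "The inner Adagrad loop above implements the projection in Line 6: writing the target surrogate profile as $(u, v)$ (notation is ours, not the paper's), it minimizes the squared hinge distance over the model parameters $(w, b)$:\n",
        "\n",
        "$$\\min_{w, b} \\; \\sum_{k=1}^{K} \\max\\big(0, \\widehat{\\mathrm{FPR}}_k(w, b) - u_k\\big)^2 + \\max\\big(0, \\widehat{\\mathrm{FNR}}_k(w, b) - v_k\\big)^2,$$\n",
        "\n",
        "where $K$ is the number of groups and $\\widehat{\\mathrm{FPR}}_k$, $\\widehat{\\mathrm{FNR}}_k$ are the group-$k$ surrogate rates. The objective is zero exactly when every surrogate is at or below its target."
      ]
    },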
    {
      "cell_type": "code",
      "metadata": {
        "id": "uyuV1tuBkqHa",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 740
        },
        "outputId": "1bc07626-fcf4-4081-9f4d-521bf6f10975"
      },
      "source": [
        "model_as = surrogate_pgd(\n",
        "    x_train, y_train, z_train, x_vali, y_vali, z_vali, surrogate_fn, \n",
        "    surrogate_tensor_fn, metric_fn)\n",
        "print_results(model_as, x_vali, y_vali, x_test, y_test)"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step = 0 | Vali G-mean = 0.227\n",
            "Step = 10 | Vali G-mean = 0.165\n",
            "Step = 20 | Vali G-mean = 0.167\n",
            "Step = 30 | Vali G-mean = 0.165\n",
            "Step = 40 | Vali G-mean = 0.165\n",
            "Step = 50 | Vali G-mean = 0.167\n",
            "Step = 60 | Vali G-mean = 0.169\n",
            "Step = 70 | Vali G-mean = 0.173\n",
            "Step = 80 | Vali G-mean = 0.171\n",
            "Step = 90 | Vali G-mean = 0.167\n",
            "Step = 100 | Vali G-mean = 0.171\n",
            "Step = 110 | Vali G-mean = 0.173\n",
            "Step = 120 | Vali G-mean = 0.167\n",
            "Step = 130 | Vali G-mean = 0.167\n",
            "Step = 140 | Vali G-mean = 0.171\n",
            "Step = 150 | Vali G-mean = 0.177\n",
            "Step = 160 | Vali G-mean = 0.181\n",
            "Step = 170 | Vali G-mean = 0.169\n",
            "Step = 180 | Vali G-mean = 0.171\n",
            "Step = 190 | Vali G-mean = 0.181\n",
            "Step = 200 | Vali G-mean = 0.173\n",
            "Step = 210 | Vali G-mean = 0.161\n",
            "Step = 220 | Vali G-mean = 0.161\n",
            "Step = 230 | Vali G-mean = 0.163\n",
            "Step = 240 | Vali G-mean = 0.158\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAARgAAADQCAYAAADcQn7hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXwV5fX48c/JHiAhrAHDGgQMoAKyqIiIggt8RbG2SqsCbm3Vtl9b2trSn1trW7evVbtoUVSqdHGhLlURKCgKAmEVCPu+BmQJW8h2fn/MBC6QZXJz5+bm3vN+vfLKzdy5M4eBHJ555nnOI6qKMcb4Ia6uAzDGRC9LMMYY31iCMcb4xhKMMcY3lmCMMb5JqOsAQqV58+baoUOHug7DmKi0cOHCvaraoqafi5oE06FDB3Jzc+s6DGOikohsDuZzdotkjPGNJRhjjG8swRhjfGMJxhjjG0swxhjfxFSC+dPMdfT5zbS6DsOYmBFTCQZg7+EiCotL6zoMY2JCTCWY9BRn2M+hwpI6jsSY2BBTCSYtJRGAQ4XFdRyJMbEhphJMeqrTgimwFowxYRFTCcZaMMaEV0wlmHQ3wRQcsxaMMeEQUwkm7UQnr7VgjAmHmEow6aluC8YSjDFhEVMJpmFSPHFij6mNCZeYSjAiQlpKIgXHrAVjTDjEVIIBpx/GWjDGhIevCUZErhaR1SKyTkQeqOD9H4vIShFZJiIzRKS9u72niMwVkRXuezeFKqb0lETrgzEmTHxLMCISD/wJuAboBowSkW6n7bYY6KOq5wFvAU+4248Ct6lqd+Bq4A8ikhGKuNJSEmygnTFh4mcLph+wTlU3qGoR8A/gusAdVHWmqh51f/wSaONuX6Oqa93XO4B8oMYFhyuSnmp9MMaEi58JJgvYGvDzNndbZe4APjp9o4j0A5KA9RW8d7eI5IpI7p49ezwFZX0wxoRPRHTyisgtQB/gydO2twb+BoxV1bLTP6eqf1XVPqrap0ULbw0c64MxJnz8XLZkO9A24Oc27rZTiMgQYDwwSFWPB2xPB/4DjFfVL0MVVHpKAoePl1BWpsTFSagOa4ypgJ8tmAVAZxHpKCJJwM3Ae4E7iEgv4EVghKrmB2xPAqYAk1T1rVAGlZ6aiCocLrLbJGP85luCUdUS4D5gKpAH/EtVV4jIoyIywt3tSaAR8KaILBGR8gT0LeBSYIy7fYmI9AxFXGlWdMqYsPF1ZUdV/RD48LRtDwa8HlLJ514HXvcjppMzqovJykj14xTGGFdEdPKG08maMNaCMcZvMZdgTlS1s7Ewxvgu5hLMiRbMcUswxvgt5hJM+coCVtXOGP/FXIKxurzGhE/MJZikhDiSE+JswqMxYRBzCQacwXbWgjHGfzGZYNJSEqwPxpgwiMkEYxMejQmPmEwwVrLBmPCIyQSTnmotGGPCITYTjLVgjAmLGE0wVjbTmHCIyFUF3Pc+FpEDIvJBqONKS0ngeEkZx0tKQ31oY0yASF1VAJxaMbf6EVv5ErJ2m2SMvyJyVQH3vRnAIT8Cs6JTxoRHxK8qUJVgVhWAU4tOGWP8ExGdvJWtKlCdYFYVACs6ZUy4ROyqAn46UXTKxsIY46uIXFXAb1aywZjwiNRVBRCR2cCbwBUisk1ErgpVbFZ0ypjwiMhVBdz3BvoVV8OkBESsBWOM3yKikzfc4uKEtOQEKzpljM9iMsGA0w9jnbzG+CtmE0x6aqL1wRjjs5hNME5NGGvBGOOnahOMiAwQkWkiskZENojIRhHZEI7g/ORUtbMWjDF+8vIU6WXgfmAhEDXTj9NTElhlLRhjfOUlwRxU1RrNEaoPnMLflmCM8ZOXBDNTRJ4E3gFODOVX1UW+RRUG6amJHD5eQlmZEhcndR2OMVHJS4Lp737vE7BNgctDH074pKUkUKZwpKjkxNQBY0xoVZtgVHVwOAIJt/SAGdWWYIzxh6
epAiIyHOgOpJRvU9VH/QoqHMqTSkFhMWeRWsfRGBOdvDymfgG4CfgBIMA3gfZVfqgeKC/ZYDVhjPGPl4F2F6vqbcB+VX0EuAjo4uXgtSz6PVpE1rpfo73+gbxKs6p2xvjOS4I55n4/KiJnAcVA6+o+VJui3yLSFHgIp4O5H/CQiDTxEKtn6VaX1xjfeUkwH4hIBk7tlkXAJuDvHj5Xm6LfVwHTVHWfqu4HpgFXezinZ4F9MMYYf3h5ivRr9+Xb7hpFKap60MOxKyr63b+SfeHUot81LRheY7aygDH+89LJ20BE/p+ITHBr5rYUkf8JZRDBFv0OdlUBgJTEeJIS4qwPxhgfeblFegVnBO9F7s/bgd94+FxNi36PCCj67emzwa4qUM4mPBrjLy8JppOqPoHTuYvbZ+JlbH1tin5PBa4UkSZu5+6V7raQSreSDcb4ystAuyIRScWZHoCIdCJgTlJlVLVERMqLfscDE8uLfgO5qvoepxb9BtiiqiNUdZ+I/BonSQE8qqr7avqHq05aqrVgjPGTlwTzEPAx0FZE3gAGAGO8HLyWRb8nAhO9nCdY1oIxxl9eniJNE5FFwIU4t0Y/UtW9vkcWBukpiew4cKz6HY0xQfFaMjML5zYnCbhURG7wL6Twccpm2i2SMX6ptgUjIhOB84AVQJm7WXHqw9Rr6am2soAxfvLSB3Ohqp4+xD8qpCUnUFhcRlFJGUkJMVv/3BjfePmtmlvBHKKokJ5qa1Qb4ycvLZhJOElmF87jaQHUnaBYrwVOF2jWKLmOozEm+nhdVeBW4CtO9sFEhXSb8GiMr7wkmD3uoLioYxMejfGXlwSzWEQmA+9z6qoC9f4pkhWdMsZfXhJMKk5iuTJgW5Q8prYWjDF+8jKSd2w4AqkLVnTKGH/VaPCHO2UgaqQlJyCCTXg0xic1HV0WVUsgxsUJjZJsCVlj/FLTBPOfmuzsYVWBS0VkkYiUiMiNp733uIgsd79uqmGcnqWnJlofjDE+qVGCUdVfed3X46oCW3BKP0w+7bPDgd5AT5w6vuNEJL0msXqVlpJgfTDG+KTSBCMid4jITwN+3i4iBSJySES+5+HYXlYV2KSqyzhzAF834DNVLVHVI8AyQryqQLn0lESbKmCMT6pqwXyPUws+5atqOtACGOXh2LVZGWApcLVbcLw5MJhTa/SGTFpKAgXH7BbJGD9U9ZhaVPXrgJ/fBFDVQreEpm9U9RMR6QvMAfYAc4HSMwIUuRu4G6Bdu3ZBnSs9NZE1+YeCD9YYU6mqWjAZgT+o6m8BRCQOaO7h2J5WBqiMqj6mqj1VdSjO06s1FexTq1UFwFowxvipqgTziYhUtDzJo8AnHo5d7aoClRGReBFp5r4+D6fglZdz1lh5H4yq+nF4Y2JaVbdIPwVeEpF1OH0iAOcDucCd1R3Yy6oC7m3QFKAJcK2IPKKq3YFEYLa70kABcIuq+tLMSEtJoEzhSFEpjZK9zJwwxnhV6W+U+/RmlIhkA93dzStVdb3Xg3tYVWABJ9ejDtynEOdJku8Ci05ZgjEmtLzMRdoAbAhDLHWivGRDwbESWjeu42CMiTIxX4i2vOiUjYUxJvRiPsGcaMFYgjEm5Cq9RRKRdFUtEJGmFb3vx1KudeFkH4w9qjYm1Krqg5kM/A+wEKfAVOBMagWyfYwrbE72wVgLxphQq+op0v+43zuGL5zwO1n421owxoRaVbdIvav6oKpGRfGplMR4kuLj7BbJGB9UdYv0dBXvKXB5iGOpM+mpVrLBGD9UdYs0OJyB1KW0FCs6ZYwfPA1dFZEeOCNrU8q3qeokv4IKN2fCo7VgjAm1ahOMiDwEXIaTYD7EqVD3Oc6SslHBik4Z4w8vA+1uBK4AdrlLmJwPRNWgeqdspt0iGRNqXhLMMVUtA0rcurj5+FRdrq5YC8YYf3hJMLkikgFMwBl0twinwly1armqwBMiskJE8kTkOXFrN/jBik7FluLSMp6fsZZ1PlYyXLr1AH
+etc6349cXVRX9/pOIDFDVe1T1gKq+AAwFRntZ7bGWqwpcDAzAKTTVA+gLDPL8p6qh9NREjhWXUlx6eu1xE23KypSfv72Mp6et4dkZ/iWAJ6eu5omPV7Nwc1TMqAlaVS2YNcBTIrLJbU30ClgFwIvarCqgOE+skoBknAJUuz2et8bKpwvYo+ro9/jHq3hn0XZapiUza3W+L/+p7DpYyBfr9wIw4bONIT9+fVJpglHVZ1X1IpyWw9fARBFZJSIPiUgXD8cOelUBVZ0LzAR2ul9TVTXv9P1E5G4RyRWR3D179ng5dIWsZENsmPDZBl78bAO3XtieX1/fg0OFJSzYGPoWxrtLtqMKw89tzdSVu9i090jIz1FfVNsHo6qbVfVxVe2Fs1zJ9cAZv+yhJCJnAzk41e6ygMtFZGAFsdW66DecWnTKRKd3Fm3jsQ/zGHZuKx4e0Z2BnZuTlBDHtLzQN4ynLN7O+W0zeGhENxLj4nj589htxVSbYEQkQUSuFZE3gI+A1cANHo5dm1UFRgJfquphVT3snvcij5+tscCymSb6zFydz8/eWsbFnZrxzE09iY8TGiQlcMnZzZmetzukBd9X7ihg1a5D3NAri5ZpKYzslcWbC7ey70hRyM5Rn1TVyTtURCbi3NrchbMudSdVvVlV3/Vw7KBXFcDp/B3kJrdEnNs031pNVnQqei3asp97Xl9E11ZpvHjrBSQnxJ94b0hOJlv3HWNt/uGQnW/K4m0kxAnXnn8WAHcO7EhhcRmvf7k5ZOeoT6pqwfwCZ+GzHFUdoaqT3ULgnrirAJSvKpAH/Kt8VQERGQEgIn1FZBvwTeBFEVnhfvwtYD3wFc6KBktV9f2a/uG8spIN0Wld/iFuf3UBLdOTeXVsP9Lcv+dyV+S0BGDaytDcJpWWKe8u2cFlXVvQtGESAJ0z07j8nJa8NmcThcVnrB0Y9aqa7Fjr2dK1WFWgFPhubc/v1YkEY/ORosbOg8e47eX5JMTFMen2frRISz5jn8z0FM5r05jpebu5d/DZtT7nnPV7yT90nJG9Tv0nfdfAbEZN+JJ3Fm3n2/2DW4G0vor5mrwAjewxdVQ5fLyE0RPnU1BYwqtj+9K+WcNK9x2Sk8mSrQfYc+h4rc87ZdF20lISTrSMyl2Y3ZRzsxrz0uwNlJXF1gJ/lmCA+DihUbLVhIkWz05fw5rdh3nhlgvokVX1tLkrclqiCjNX5dfqnEeLSvh4xS6Gn9ualMT4U94TEe66NJsNe48wo5bnqW8swbjSUxKsBRMFVu0qYOIXmxjVry2XdK5+CfVurdM5q3FKrR9XT12xi6NFpYzsVfFQr2E9WpGVkcqEz6J2ibEKWYJxpaUkWh9MPVdWpvxqynIapybys6vO8fQZEWFIt0w+X7u3Vp2w7yzaTlZGKn07VLgIBwnxcdx+SUfmb9rH4i37gz5PfWMJxpWeai2Y+u6tRdvI3byfB645hybuUxwvhuRkcqy4lDnu8P6ayi8o5It1exnZK4u4uMrn5N7Uty3pKQm8NDt2Bt5ZgnGlpSRaH0w9tv9IEb/7MI8+7ZtwY+8zHkxWqX92UxomxTNtZXD9I+8u2UGZwsjeVc+EaZScwHcubM9Hy3ey5eujQZ2rvrEE47I+mPrtiamrKCgs4dfX96iyFVGR5IR4BnVtwYy83UE95Xln8XbOb9OYTi0aVbvvmIs7EB8nTPwiNloxlmBc1oKpvxZt2c/f52/l9gEdyGmdHtQxhuRkkn/oOMt3HKzR51btKiBvZ0Glnbuny0xP4bqeWfxzwVb2x8D0AUswrvI+mFDOSzH+Kykt41dTltMqPYUfDfEyyb9ig7u2JE5geg1H9U5ZtP2UqQFe3DUwm2PFpbwxL/qnD1iCcaWlJFJaphwtqt/DuT/8aicj//wF/1m2s06S5apdBdw1KZenpq4Oy/kmzd3Myp0FPHRtNxole1oko0JNGibRp31TpuV574cpLVP+vWQ7g7q0oFmjM0cKV6Zrqz
QGdWnBq3M2c7ykfv97q44lGFc0FJ2as24vP/rHYvJ2FnDv5EXc+MJcFoXpkWh+QSEPvL2MYc/OZkbebv44cx25m/yt5ra7oJD/m7aGQV1acHWPVrU+3pBuLcnbWcD2A8c87T93/dfsLjhebeduRUZf3J69h48zZ/3XNf5sfWIJxnVywmP97IdZseMgd/9tIR2bN2TOA1fw+xvOZcu+o9zw5zncN3kRW/f589TiaFEJz05fy2VPzeLtRdsYc3FHPv/55WRlpDJ+ynJfy5D+5j95FJWW8ciI7oSiZPOQnEwAZngcdPfO4m2kJSec+FxNXNypOQ2S4mt8S1bfWIJxnWzB1L8Es3XfUca8soC0lAReu70fTRsmcXO/dswadxk/vKIz0/N2c8XTn/LbD/M4GKLBhKVlypu5Wxn81Cyeme60IqbdP4gHr+3GWRmpPDyiO6t3H2KiT8WWZq/dw/tLd3DvZWfToXnlc41qIrtFI7KbN/Q0u/poUQkfL9/FsAqmBniRkhjPoC4tQl6PJtL4mmCCXVVARAaLyJKAr0IRud7PWMuLTtW3qnb7jhQxeuJ8jheX8trt/WjdOPXEew2TE/jx0C7MGjeYET3PYsLsDVz25Exem7OpVpPuvtzwNdc+/zk/fWsZrRqn8ub3LuIvt1xwyi/60G6ZDO2WyR+mr2Xb/tC2no6XlPLguyvo0KwB3x2UHdJjD+mWyZcbvq72P5pPVux2pgYEcXt04lw5mewuOM7y7QVBHyPS+ZZgarOqgKrOVNWeqtoTuBw4CnziV6zgjIOB+nWLdLSohNtfXcD2A8d4eUxfumSmVbhfq8YpPPXN83n/vkvIaZ3OQ++t4MH3lgf1P+cnK3bxnZfmcfBYMc/e3JMp37+40uHxD4/o7nx/b2WNz1OZ0jJl3JvL2Lj3CI9e1yOo1kNVhuRkUlyqzF5b8ajewuJSXpq9gYffX0GbJqn0q+TP7sXgc5wnV9NW7gr6GJHOzxZMbVYVCHQj8JGq+jr08WTh7/rRgikpLeMHkxezbNsBnhvVq9Jf8kA9shrzxp39+e6gbF7/cgtP1PBJzxfr9nLf5MX0yGrM1Psv5bqeVQ+Nz8pI5f6hzi3aJytq/0ukqvzq31/x/tId/OKac7i0S/B1mCvTu10GGQ0Sz+gbKS1T3l64jSue/pTf/CeP89pk8MqYvjUe1BeoaRBPruobPxNM0KsKnOZm4O8VvRGqVQWAE9XO6kMLRlUZP2U5M1bl8+h1Pbiqu/cnKCLCA1efw3f6t+Mvs9Z7Xhxs4eb93DUpl+wWDXltbF/Pj4THDujIOa3SePi9FRw5HnzyVlV+99Eq/j5/K/cO7sR3B3UK+lhVSYiP4/KuLfnv6nxKSstQVWauzmf4c7P5yZtLadowiTfu7M+k2/vRuZIWY02UP7kK9W1kpIjoTl4RaQ2ci1N28wyhWlUAICUxjsR4qRctmGemreGfuVv54eVnc8uF7Wv8eRHh19f14LqeZ/HEx6v5WzX1YlfuKGDsK/NpmZbMpDv6kdHA+0TCxPg4HhvZgx0HC3l2xtoax1ruj/9dx18/28Doi9oz7squQR/HiyHdMjlwtJhJczczasKXjH1lAUeLSnl+VC/evXcAA86uvgyE53OdeHIVna2Y4EcmVa82qwqU+xYwRVV9b1aISMSXbDh8vITn/7uWFz/dwE192nL/0OBHrsbFCU9983yOHC/hwXeX0yg5/oxSjwAb9hzmtonzaJicwOt39qdlWkqNz3VB+6aM6teWlz/fyMheWTUezv/KFxt5etoabuidxUPXhuaRdFUu7dKCpPg4Hv1gJc0aJvHIiO6M6teOpITQ/3+c3aIR2S0aMj1vN6Mv7hDy49c1P1swtVlVoNwoKrk98kOkTngsKS1j8rwtXPbkLF78dAPfvKANj43sUetftMT4OP747d5c2LEZ495cdkY/yfYDx7jlpXmowt/u6E+bJg2CPtfPrz6HxqmJjJ/yVY2eYL2Zu5VH3l
/JVd0zeeIb59Wqz8OrRskJ/Ozqrvx4aBc+/dlgRl/cwZfkUm5ojvPkqj7cnteUb1etlqsKICIdcFpAn/oV4+kibcKjqjJzVT7XPDubX075io7NG/Dvewfw5DfPJyE+NH91KYnxTBjdh3OzGnPf5MV8sc55erLn0HFueWkeh46X8Nrt/Ti7ZfUzhauS0SCJXw7LYdGWA/wzd2v1HwA++monP397GQM7N+e5Ub1C9mf24s6B2fzwis61mn7g1ZBuzpOrz9bUrh8xEvn6N6aqH6pqF1XtpKqPudseVNX33NcLVLWNqjZU1Waq2j3gs5tUNUtVw7YifSQVnVqx4yC3vDyPsa8uoLi0jBduuYB/ffcierbNCPm5GiUn8OrYvmS3aMhdk3KZuTqfW1+ex66Dhbwypm+1dW29+kbvLPp3bMrvP1rF3sNVF9n+dM0efviPxfRsm3HGekbRpne7JjRtmBSVo3r9T8/1SFpyIvkFlS/Cpaps3HuE9s0aEu9TU33nwWM8NXUN7yzeRkZqIg9f241v92/vaxMdnBbGpDv68a0X5jL2lQUkxcfx0ug+9KnFOI/TiQiPjezBNc/OZswr88luXnGrSHHGhpzdMo1XxvajQVJ0/zONjxMuP6cln6zYRXFpGYlhbKn5Lbr/5mqoshZM+aPK/5u2huXbC+iamcYvh+cwKITjMPIPFfKXWet5Y94WULj70mzuuexsGqcmVv/hEGmZlsLrd/bngbe/4raL2vsyzuTslmk8eG13Xvl8I19tr7z2Sp/2TfnDzT3D+uevS0NyMnlr4TZyN+3nok7N6jqckLEEE+D0PhhV5Yt1X/P0tNUs3nKAtk1T+cnQLry5cBujJ85nYOfm/OKaHLqdFVyRI3CG+r/46Xpem7uJ4lLlG72z+MHlnWnbNPgO1dpo06QBr9/Z39dz3Hphe24N4vF6NBvYuTlJCXFMz9ttCSZapackcrSolJLSMhZu3s/T09Ywf+M+zmqcwu9uOJcbL2hDYnwcd7sjYZ+bsZbhz8/mxt5t+MmVXWnV2Psj3IPHinlp9gYmfr6Ro8WlXN8zix9d0TlkE/dM/dIwOYEBnZoxPW83vxqe4/uj+HCxBBOgfEb1t1+ax/yN+2iRlswjI7pzc7+2p3QyJifEc8clHbmxdxv+OHMtr83ZzPvLdnD3wGzuHtSp0icPpWXKwWPFvPHlZibM3kBBYQnDz23N/w7pHJJRoaZ+G9Itk/FTlrMu/3DU/HuwBBOgWSNnhOq6/MOMH5bDLRe2JzWp8qcXjRskMn54N267qANPTF3Nc/9dx+T5W+jQrCHHiks5VlTKseJSjrrfi0pOPhAbkpPJ/UM70/2s0DyhMfXfFedkMp7lTMvbHTUJRqKlFkWfPn00Nze3VscoLC5lRl4+g7q2CGr8w+It+/nzrPUcOV5Cg6R4UhLjSU2Md14nxdMgMYHUpDguzG7GeW1C/7jZ1H8j/vg58XHClHsG1HUopxCRharap6afsxZMgJTEeIaf1zroz/dq14QJt9X478CYE4bkZPLM9DXkHyoMalpGpImeB+7GRIEhOZmowsxV0TH50RKMMREkp3UaWRmpQa8yWa6ktIx/L97Omt2HQhRZcOwWyZgIIiIM7ZbJPxZs4VhRaZUPGSpSPij0tx+uYl3+YZo0SOTt719MtodVJ/1gLRhjIsyQnEwKi8tOTDz1auWOAm55eR63v5pLaZnyuxvOJU6E2ybOJ/9QoU/RVs0SjDERpl/HpqQlJzDd4/IpuwsK+dlbSxn+/GxW7HAWoZv6v5cyql87Jo7py9eHixgzcUGdrJgRkasKuO+1E5FPRCRPRFa65RuMiXpJCXEM6tqC6Xn5VdbOOVpUwh+mr+GyJ2cxZfF27rykI5+OG8zYAR1PTI49v20Gf76lN6t3H+J7ry88ZSxWOPjWBxOwqsBQnHq8C0TkPVUNLDFfvqrAuAoOMQl4TFWniUgjqi4MbkxUGdotkw+W7eT8Rz8hrpJpA4
XFpRwvKWPYua34+dXn0L5ZxdNMBndtyePfOI9xby7lp28t5Zlv9QxL4S7wt5P3xKoCACJSvqrAiQSjqpvc905JHu7yJgmqOs3dr/IaCsZEoau6t+KeyzpVuVZ6nAjDzm3lqaTGjRe0YXdBIU9OXU1megq/HJYTynAr5WeCqWhVAa/TdLsAB0TkHaAjMB14QFVPudoicjdwN0C7du1qHbAxkSIlMZ6fXX1OSI95z2Wd2F1QyF8/20DLtGTuHBjaResqEqmdvAnAQJxbp75ANs6t1ClCuaqAMdFORHjo2u5c06MVv/lPHu8t3eH7Of1MMLVZVWAbsMRdtK0E+DfQO8TxGRNz4uOEZ27qSb8OTfnJv5Ywp4aPwmsqUlcVWABkiEh5s+RyAvpujDHBS0mMZ8JtfejYvCG/+2hVUEsIe+VbH4yqlohI+aoC8cDE8lUFgFxVfU9E+gJTgCbAtSLyiKp2V9VSERkHzBCn8s5CYIJfsRoTaxo3SOS12/uRFB/na3ErK9dgjKlWsOUaIrWT1xgTBSzBGGN8YwnGGOMbSzDGGN9YgjHG+CZqniKJyB5gs4ddmwP+ji6qPYsxNCI9xkiPD07G2F5VazxcPmoSjFcikhvM47ZwshhDI9JjjPT4oPYx2i2SMcY3lmCMMb6JxQTz17oOwAOLMTQiPcZIjw9qGWPM9cEYY8InFlswxpgwsQRjjPFNTCWY6lY5qCsisklEvhKRJSKS625rKiLTRGSt+71JGOOZKCL5IrI8YFuF8YjjOfeaLhORsBQGqyTGh0Vku3sdl4jIsID3fuHGuFpErgpTjG1FZKa7KsYKEfmRuz0irmUV8YXuOqpqTHzh1KRZj1N+MwlYCnSr67jc2DYBzU/b9gROHWKAB4DHwxjPpTgVBJdXFw8wDPgIEOBCYF4dxvgwMK6Cfbu5f9/JODWe1wPxYYixNdDbfZ0GrHFjiYhrWUV8IbuOsdSCObHKgaoWAeWrHESq64DX3NevAdeH68Sq+hmwz2M81wGT1PElTiXC1nUUY2WuA/6hqsdVdSOwDuffg69UdaeqLnJfHwLycIrhR8S1rCK+ytT4OsZSgqlolYOqLmY4KfCJiCx0V0oAyFTVne7rXUBm3YR2QmXxRNp1vc+9vZgYcEghlNwAAATjSURBVFtZ5zG6Cwf2AuYRgdfytPggRNcxlhJMJLtEVXsD1wD3isilgW+q0z6NmPEEkRZPgL8AnYCewE7g6boNx+EuHPg28L+qWhD4XiRcywriC9l1jKUEU5tVDnylqtvd7/k4NYr7AbvLm8fu9/y6ixCqiCdirquq7lbVUlUtw6nhXN58r7MYRSQR55f3DVV9x90cMdeyovhCeR1jKcHUZpUD34hIQxFJK38NXAksx4lttLvbaODduonwhMrieQ+4zX0CciFwMKD5H1an9VeMxLmO4MR4s4gki0hHoDMwPwzxCPAykKeq/xfwVkRcy8riC+l19LsnPZK+cHrp1+D0fo+v63jcmLJxeuaXAivK4wKaATOAtTgrWzYNY0x/x2kaF+PcZ99RWTw4Tzz+5F7Tr4A+dRjj39wYlrm/DK0D9h/vxrgauCZMMV6Cc/uzDFjifg2LlGtZRXwhu442VcAY45tYukUyxoSZJRhjjG8swRhjfGMJxhjjG0swxhjfWIKJESKSKSKTRWSDOyVhroiMrOUxHxaRce7rR0VkSJDH6Rk4Y9fD/hkick8w5zLhZQkmBrgDqv4NfKaq2ap6Ac5AwzYV7JsQzDlU9UFVnR5kiD1xxl94lQFYgqkHLMHEhsuBIlV9oXyDqm5W1ecBRGSMiLwnIv8FZohIIxGZISKLxKlTc2LWuYiMF5E1IvI50DVg+6sicqP7+gIR+dRtKU0NGBY/S0QeF5H57jEGuqOqHwVucmuP3BQYuIh0d/df4k6+6wz8HujkbnvS3e+nIrLA3ecRd1sHEVklIm+ISJ6IvCUiDdz3fu/WQVkmIk/5cM0NxNZI3lj9An4IPF
PF+2NwRsOWjyhNANLd181xpuULcAHOCM8GQLq7fZy736vAjUAiMAdo4W6/CZjovp4FPO2+HgZMDzj/HyuJ7XngO+7rJCAV6MCpdWCuxClOLTj/aX6AUy+mA85I1QHufhOBcTgjaVdzsiZ1Rl3/HUXrV1DNYVO/icifcIaJF6lqX3fzNFUtr68iwG/dWd1lOFPyM4GBwBRVPeoep6K5XF2BHsA0586MeJwh/eXKJ/wtxEkA1ZkLjBeRNsA7qrrWPW6gK92vxe7PjXDmyWwBtqrqF+7213GS7R+AQuBlEfkAJyEZH1iCiQ0rgG+U/6Cq94pIcyA3YJ8jAa+/A7QALlDVYhHZBKR4PJcAK1T1okreP+5+L8XDvz9VnSwi84DhwIci8l1gQwXn/J2qvnjKRqfGyelzYVRVS0SkH3AFTqvrPpzbSBNi1gcTG/4LpIjI9wO2Nahi/8ZAvptcBgPt3e2fAdeLSKo7A/zaCj67GmghIheBUw5ARLpXE98hnJKNZxCRbGCDqj6HM+v4vAr2nwrc7tY1QUSyRKSl+1678liAbwOfu/s1VtUPgfuB86uJzwTJEkwMUKej4XpgkIhsFJH5OKUaf17JR94A+ojIV8BtwCr3OIuAf+LM/P4IpwTG6ecqwmkVPC4iS3Fm6F5cTYgzgW4VdfIC3wKWi8gSnFuvSar6NfCFiCwXkSdV9RNgMjDXjfktTiag1ThFvPKAJjjFlNKAD0RkGfA58ONq4jNBstnUJmq5t0gfqGqPOg4lZlkLxhjjG2vBGGN8Yy0YY4xvLMEYY3xjCcYY4xtLMMYY31iCMcb45v8DtP4R/uqXiFQAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 288x216 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\n",
            "Vali G-mean: 0.1584\n",
            "Test G-mean: 0.1777\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kJEBGRWNJ53z",
        "colab_type": "text"
      },
      "source": [
        "The proposed method yields a better (lower) G-mean value on both the validation and test sets than the two baselines."
      ]
    }
  ]
}